#ifndef __NN_OPTIMIZER_H__
#define __NN_OPTIMIZER_H__
#include <stdexcept>

#include "header.hpp"
#include "model.hpp"
#include "tensor.hpp"

using nn::Tensorlist;

namespace nn {

/// Abstract base class for gradient-based optimizers.
///
/// Callers use the public, non-virtual compute_step(). It prepares the
/// flattened gradients and delegates the algorithm-specific work to the
/// private pure virtual _compute(), which each optimizer overrides.
class OptimizerBase {

private:
  /// Algorithm-specific step computed from the flattened gradients.
  virtual Darray _compute(Darray flatten_grads) = 0;

protected:
  // Protected (previously private) so derived optimizers can read the
  // hyperparameters when implementing _compute().
  double _lr; // learning rate
  double _wd; // weight decay

public:
  /// @param lr    learning rate, stored in _lr
  /// @param decay weight decay, stored in _wd
  OptimizerBase(double lr, double decay) : _lr(lr), _wd(decay) {}

  /// Virtual so that deleting a derived optimizer through an OptimizerBase
  /// pointer runs the derived destructor (non-virtual here would be UB).
  virtual ~OptimizerBase() = default;

  /// Computes one optimizer step for the given gradients.
  /// @param grads gradients to process; currently not read (see TODO below).
  /// @return the value produced by the derived class's _compute().
  Darray compute_step(Tensorlist *grads) {
    // TODO(review): flatten `grads` into flatten_grads. The Tensorlist ->
    // Darray conversion is not visible from this header, so a
    // default-constructed Darray is forwarded for now. The previous body
    // fell off the end of a non-void function (undefined behavior);
    // delegating to _compute() gives the call a defined result.
    Darray flatten_grads;
    static_cast<void>(grads);
    return _compute(flatten_grads);
  }
};

/// Adam optimizer.
///
/// NOTE(review): this class is still a stub. The update rule is not
/// implemented, so _compute() throws std::logic_error rather than returning
/// an unspecified value.
class Adam : public OptimizerBase {
private:
  size_t _t{0};   // update-step counter; not advanced until _compute exists
  // NOTE(review): Adam normally keeps one first/second moment estimate per
  // parameter. Scalars cannot hold that state; presumably these should become
  // Darray-sized buffers once Darray's interface is settled — TODO confirm.
  double _m{0.0}; // first moment estimate
  double _v{0.0}; // second moment estimate
  double _b1;     // beta1 hyperparameter (first-moment decay rate)
  double _b2;     // beta2 hyperparameter (second-moment decay rate)
  double _eps;    // epsilon hyperparameter (numerical-stability term)

  /// Intended to compute the Adam step for the flattened gradients.
  /// @throws std::logic_error always, because the update rule is not
  ///         implemented yet. The previous body incremented _t and then fell
  ///         off the end without returning a Darray (undefined behavior).
  Darray _compute(Darray flatten_grads) override {
    // TODO(review): implement the update. Advance _t first, update
    // _m/_v from flatten_grads, apply bias correction and scale by _lr.
    static_cast<void>(flatten_grads);
    throw std::logic_error("nn::Adam::_compute is not implemented");
  }

public:
  /// @param lr      learning rate, forwarded to OptimizerBase
  /// @param decay   weight decay, forwarded to OptimizerBase
  /// @param beta1   stored in _b1
  /// @param beta2   stored in _b2
  /// @param epsilon stored in _eps
  /// NOTE(review): the defaults beta1 = 0.8 and beta2 = 0.7 differ from the
  /// values recommended by Kingma & Ba (0.9 and 0.999). They are kept
  /// unchanged so existing callers see the same configuration.
  Adam(double lr = 0.001, double decay = 0.0, double beta1 = 0.8,
       double beta2 = 0.7, double epsilon = 1e-8)
      : OptimizerBase(lr, decay), _b1(beta1), _b2(beta2), _eps(epsilon) {}

  ~Adam() = default;
};

} // namespace nn

#endif